import numpy as np
import time


class TheoremTwoSimulation:
    def __init__(self, dims=5, timesteps=1000, energy_init=100, dt=0.01, seed=42):
        np.random.seed(seed)
        self.dims = dims  # Dimensionality of representation space
        self.timesteps = timesteps
        self.dt = dt
        
         # Initialize representation state
        self.Rep = np.zeros(dims)
        # Initial uncertainty (covariance matrix)
        self.cov = np.eye(dims) * 7 # Start with moderately high uncertainty
        
        self.potential_centers = np.array([-2 * np.ones(dims), 2 * np.ones(dims)]) # First attractor basin,Second attractor basin
        self.potential_strengths = np.array([1.0, 1.5])
        
        # Metabolic energy parameters
        self.energy_init = energy_init
        self.energy = energy_init
        self.energy_recovery_rate = 1.0
        self.energy_consumption_base = 2.0
        self.energy_consumption_exploration = 3.0
        
         # Exploration-exploitation parameters
        self.lambda1 = 1.0   # Exploitation weight 2.0
        self.lambda2 = 1.5  # Exploration weight 0.5
        
        # Tracking variables
        self.uncertainty_history = np.zeros(timesteps)
        self.energy_history = np.zeros(timesteps)
        self.exploitation_history = np.zeros(timesteps)
        self.exploration_history = np.zeros(timesteps)
        self.rep_history = np.zeros((timesteps, dims))
        self.basin_history = np.zeros(timesteps)
        
    def potential_function(self, rep):
        """Double-well potential function representing the current state's energy"""
        dist1 = np.sum((rep - self.potential_centers[0])**2)
        dist2 = np.sum((rep - self.potential_centers[1])**2)
        return self.potential_strengths[0] * dist1 + self.potential_strengths[1] * dist2
        
    def potential_gradient(self, rep):
        """Gradient of the potential function"""
        grad1 = 2 * self.potential_strengths[0] * (rep - self.potential_centers[0])
        grad2 = 2 * self.potential_strengths[1] * (rep - self.potential_centers[1])
        return grad1 + grad2
    
    def expected_free_energy(self, rep):
        """Computation of expected free energy for future states
           Simplification: we assume a Gaussian approximation"""
        # Exploitation term (deviation from known optimal states)
        dist1 = np.sum((rep - self.potential_centers[0])**2)
        dist2 = np.sum((rep - self.potential_centers[1])**2)
        # Find closest attractor
        closest_center = self.potential_centers[0] if dist1 < dist2 else self.potential_centers[1]
        # Exploration term (entropy maximization)
        # Higher entropy when uncertainty increases
        exploitation_cost = np.sum((rep - closest_center)**2)
        # exploration_benefit = np.sqrt(np.trace(self.cov))
        # exploration_benefit = -np.log(np.sqrt(np.linalg.det(self.cov)))
        exploration_benefit = -np.log(1 + np.sqrt(np.trace(self.cov)))  # Softer scaling

        # Calculate expected free energy
        G = self.lambda1 * exploitation_cost + self.lambda2 * exploration_benefit
        return G
    
    def free_energy_gradient(self, rep):
        """Gradient of expected free energy"""
        # Simplified gradient calculation
        eps = 1e-5
        grad = np.zeros_like(rep)
        
        for i in range(len(rep)):
            rep_plus = rep.copy()
            rep_plus[i] += eps
            
            rep_minus = rep.copy()
            rep_minus[i] -= eps
            
            grad[i] = (self.expected_free_energy(rep_plus) - 
                       self.expected_free_energy(rep_minus)) / (2 * eps)
            
        return grad
    # While more precise, this earlier version does not manage exploration/exploitation trade-off, not taking 
    # uncertainty (covariance) into account

    # def free_energy_gradient(self, rep):
    #     """Gradient of expected free energy"""
    #     dist1 = np.sum((rep - self.potential_centers[0])**2)
    #     dist2 = np.sum((rep - self.potential_centers[1])**2)
    #     closest_center = self.potential_centers[0] if dist1 < dist2 else self.potential_centers[1]
    #     exploitation_grad = 2 * self.lambda1 * (rep - closest_center)
    #     exploration_grad = self.lambda2 * 0.1 * np.sum(self.cov, axis=0) / np.sqrt(np.trace(self.cov) + 1e-6)
    #     return exploitation_grad + exploration_grad
    
    def update_uncertainty(self, exploration_rate):
        """Update uncertainty (covariance) based on exploration/exploitation balance"""
        noise = np.random.randn(self.dims, self.dims) * 0.2
        noise = np.dot(noise, noise.T)  # Positive semi-definite

        # Uncertainty decreases in exploitation mode
        if exploration_rate < 0.5:
            # When exploiting, uncertainty decreases around the current position
            self.cov = self.cov * 0.995         # Slowed reduction slightly (0.995)
        else:
            # # When exploring, uncertainty increases
            # noise = np.random.randn(self.dims, self.dims) * 0.009
            # noise = np.dot(noise, noise.T)      # Ensure positive semi-definite
            # self.cov = self.cov * 1.0060 + noise
            self.cov = self.cov + noise * 0.3  # Slower additive growth
            # Cap covariance trace to prevent explosion
            trace = np.trace(self.cov)
            if trace > 1000:  # Arbitrary cap, adjust as needed
                self.cov *= 1000 / trace
        
         # Ensure numerical stability        
        self.cov = (self.cov + self.cov.T) / 2

        # Add minimum uncertainty floor to prevent collapse
        min_eigenval = np.min(np.linalg.eigvalsh(self.cov))
        if min_eigenval < 0.01:
            self.cov += np.eye(self.dims) * (0.01 - min_eigenval)
    
    def update_energy(self, exploration_rate):
        """Update metabolic energy based on activity and exploration"""
        # Base energy consumption
        energy_consumption = self.energy_consumption_base + exploration_rate * self.energy_consumption_exploration        
        # Energy recovery
        energy_recovery = self.energy_recovery_rate * (self.energy_init - self.energy)
        if energy_recovery < 0:
            energy_recovery = 0
            
        # Update energy
        self.energy = max(0, self.energy - energy_consumption + energy_recovery)

        if self.timesteps == 1500:  # Energy fluctuation scenario
            self.energy += 0.5 * np.sin(self.timesteps * self.dt * 0.01)  # Small oscillation
    
    def get_damping(self):
        """Calculate damping based on available energy"""
        # Damping is inversely proportional to energy
        # As energy decreases, damping increases
        # return 0.005 + 0.5 / (self.energy + 1.0)  # Reduced damping
        if self.energy < 0.2:
            return 0.1  # High damping when energy is low
        return 0.005 + 0.5 / (self.energy + 1.0)  # Normal damping
    
    def get_noise_amplitude(self):
        """Calculate noise amplitude based on uncertainty"""
        # Noise amplitude proportional to uncertainty (emphasize energy, sustain exploration) 
        # Reduce noise if energy is low or we're close to an attractor
        # if self.energy < 0.2:  # Low energy
        #     return 0.05  # Small noise to avoid excessive drift
        # # return 0.5 * (self.energy / self.energy_init) * (1 + np.sqrt(np.trace(self.cov)))  # Added base level
        # return np.sqrt(np.trace(self.cov)) * 0.02 * (self.energy / self.energy_init)
        # base_amplitude = np.sqrt(np.trace(self.cov)) * 0.2 * (self.energy / self.energy_init) * (0.9 + 0.2 * np.random.rand())
        base_amplitude = np.sqrt(np.trace(self.cov)) * 1.5 * (self.energy / self.energy_init)
        if self.energy < 0.2:
            return 0.2  # Low energy, minimal noise
        elif self.energy > self.energy_init * 0.75:  # High energy (Tests 2, 4)
            return base_amplitude * 1.5  # Boost noise for transitions
        return base_amplitude
    
    def get_basin(self, rep):
        """Determine which attractor basin the representation is in"""
        dist1 = np.sum((rep - self.potential_centers[0])**2)
        dist2 = np.sum((rep - self.potential_centers[1])**2)
        return 0 if dist1 < dist2 else 1
    
    def is_close_to_attractor(self, rep, threshold=0.5):
        """Check if the current representation is close to an attractor basin"""
        dist1 = np.sum((rep - self.potential_centers[0])**2)
        dist2 = np.sum((rep - self.potential_centers[1])**2)
        closest_dist = min(dist1, dist2)
        return closest_dist < threshold  # If distance to closest basin is smaller than the threshold

    def run_simulation(self):
        """Run the simulation for specified timesteps"""
        # print("\n===== THEOREM II.B: OPTIMAL EXPLORATION-EXPLOITATION TRADEOFF SIMULATION =====")
        print(f"Parameters: dimensions={self.dims}, timesteps={self.timesteps}, dt={self.dt}")
        print("\n")
        
        start_time = time.time()
        
        # Main simulation loop
        for t in range(self.timesteps):
            # Calculate gradient of potential
            potential_grad = self.potential_gradient(self.Rep)
            # Calculate expected free energy gradient
            free_energy_grad = self.free_energy_gradient(self.Rep)
            # Calculate damping based on energy
            damping = self.get_damping()
            # Calculate noise amplitude based on uncertainty
            noise_amplitude = self.get_noise_amplitude()
            # Generate noise term
            noise = np.random.randn(self.dims) * noise_amplitude
            # Calculate exploration vs exploitation rate
            # Higher noise amplitude means more exploration
            exploration_rate = noise_amplitude / (0.1 + noise_amplitude)
            
            # Update representation using Theorem's 2 logic/equation
            dRep = (-potential_grad - free_energy_grad - damping * self.Rep + noise) * self.dt
            self.Rep += dRep

            # Check if the system is close to an attractor basin
            if self.is_close_to_attractor(self.Rep, threshold=0.9):
                # If close, reduce exploration and increase exploitation
                exploration_rate = 0.001  # Very low exploration, focus on exploitation
            else:
                # Otherwise, maintain normal exploration
                exploration_rate = self.get_noise_amplitude() / (0.1 + self.get_noise_amplitude())
            
            
            # Update uncertainty based on exploration/exploitation
            self.update_uncertainty(exploration_rate)

            # Update energy based on activity
            self.update_energy(exploration_rate)
            
            # Log current state
            self.uncertainty_history[t] = np.trace(self.cov)
            self.energy_history[t] = self.energy
            self.exploitation_history[t] = np.sum(potential_grad**2)
            # self.exploration_history[t] = noise_amplitude
            self.exploration_history[t] = exploration_rate
            self.rep_history[t] = self.Rep
            self.basin_history[t] = self.get_basin(self.Rep)
            
            # Print progress every 10%
            if t % (self.timesteps // 10) == 0 or t == self.timesteps - 1:
                progress = (t + 1) / self.timesteps * 100
                elapsed = time.time() - start_time
                print(f"Progress: {progress:.1f}% (Step {t+1}/{self.timesteps}, Time: {elapsed:.2f}s)")

        
        # Output final results
        self.print_simulation_analysis()
    
    def print_simulation_analysis(self):
        print("\n===== SIMULATION RESULTS =====")
        
        # 1. Final state assessment
        final_basin = int(self.basin_history[-1])
        basin_stability = np.sum(self.basin_history[-100:] == final_basin) / 100
        print(f"1. ATTRACTOR CONVERGENCE ANALYSIS:")
        print(f"   - Final position: Basin {final_basin+1} (depth: {self.potential_strengths[final_basin]:.2f})")
        print(f"   - Stability: {basin_stability:.2f} (last 100 timesteps)")
        print("   - RESULT: System " + ("CONVERGED" if basin_stability > 0.95 else "MOSTLY CONVERGED" if basin_stability > 0.7 else "DID NOT CONVERGE") + " to a stable attractor basin")
        
        # Calculate correlation between uncertainty and exploration
        uncertainty_vs_exploration = np.corrcoef(self.uncertainty_history, self.exploration_history)[0, 1]
        # 2. Exploration-Exploitation assessment
        print(f"\n2. EXPLORATION-EXPLOITATION TRADEOFF ANALYSIS:")
        print(f"   - Correlation between uncertainty and exploration: {uncertainty_vs_exploration:.4f}")
        print("   - RESULT: " + ("STRONG POSITIVE" if uncertainty_vs_exploration > 0.7 else "MODERATE POSITIVE" if uncertainty_vs_exploration > 0.3 else "WEAK") + " correlation " + ("confirms" if uncertainty_vs_exploration > 0.3 else "suggests") + " that higher uncertainty leads to more exploration")
        
        # 3. Energy and exploration relationship
        energy_vs_exploration = np.corrcoef(self.energy_history, self.exploration_history)[0, 1]
        print(f"\n3. METABOLIC ENERGY IMPACT ANALYSIS:")
        print(f"   - Correlation between energy and exploration: {energy_vs_exploration:.4f}")
        print("   - RESULT: " + ("STRONG POSITIVE" if energy_vs_exploration > 0.7 else "MODERATE POSITIVE" if energy_vs_exploration > 0.3 else "WEAK") + " correlation " + ("confirms" if energy_vs_exploration > 0.3 else "suggests") + " that energy availability enables exploration")
        
        # 4. Test transition between basins
        basin_transitions = np.sum(np.abs(np.diff(self.basin_history)))
        print(f"\n4. BASIN TRANSITION ANALYSIS:")
        print(f"   - Number of transitions between basins: {basin_transitions}")
        if basin_transitions > 0:
            # Measure correlation between transitions and uncertainty
            transitions = np.abs(np.diff(self.basin_history))
            transitions = np.append(transitions, 0)
            uncertainty_before_transition = [self.uncertainty_history[max(0, i-10):i].mean() for i in range(1, len(transitions)) if transitions[i-1] > 0]
            if uncertainty_before_transition:
                avg_uncertainty_before_transition = np.mean(uncertainty_before_transition)
                avg_uncertainty_overall = np.mean(self.uncertainty_history)
                print(f"   - Average uncertainty before transitions: {avg_uncertainty_before_transition:.4f}")
                print(f"   - Average uncertainty overall: {avg_uncertainty_overall:.4f}")
                print("   - RESULT:" + ("Transitions occur during periods of HIGHER uncertainty" if avg_uncertainty_before_transition > avg_uncertainty_overall * 1.2 else "No clear RELATIONSHIP between uncertainty and transitions"))
        else:
            print("   - RESULT: No transitions between basins occurred during simulation")
        
        # 5. Free energy minimization analysis
        early_free_energy = np.mean([self.expected_free_energy(self.rep_history[t]) for t in range(100)])
        late_free_energy = np.mean([self.expected_free_energy(self.rep_history[t]) for t in range(-100, 0)])
        print(f"\n5. FREE ENERGY MINIMIZATION ANALYSIS:")
        print(f"   - Early free energy (first 100 steps): {early_free_energy:.4f}")
        print(f"   - Late free energy (last 100 steps): {late_free_energy:.4f}")
        print("   - RESULT: System " + ("MINIMIZED" if late_free_energy < early_free_energy else "DID NOT minimize") + " free energy" + (f" by {(early_free_energy - late_free_energy) / early_free_energy * 100:.2f}%" if late_free_energy < early_free_energy else ""))
        
        # 6. Verify dynamic uncertainty model
        uncertainty_stability = np.std(self.uncertainty_history[-100:]) / np.mean(self.uncertainty_history[-100:])
        print(f"\n6. UNCERTAINTY MODEL ANALYSIS:")
        print(f"   - Coefficient of variation in final uncertainty: {uncertainty_stability:.4f}")
        print("   - RESULT: Uncertainty " + ("remains DYNAMIC" if uncertainty_stability > 0.02 else "STABILIZED") + " to equilibrium state")
        
        # 7. Summary of findings
        print("\nVERIFICATION SUMMARY =========")
        if self.testN == 1:  # Base            
            success = basin_stability > 0.7 and uncertainty_vs_exploration > 0.3 and energy_vs_exploration > 0.3 and abs(late_free_energy - early_free_energy) < 10
            print("Test 1: Moderate convergence, balanced tradeoff")
        elif self.testN == 2:  # High Energy
            success = basin_stability < 0.8 and uncertainty_vs_exploration > 0.5 and energy_vs_exploration > 0.3 and late_free_energy >= early_free_energy
            print("Test 2: Transitions, high exploration")
        elif self.testN == 3:  # Low Energy
            success = basin_stability > 0.9 and uncertainty_vs_exploration > 0.3 and energy_vs_exploration < 0.3 and late_free_energy < early_free_energy
            print("Test 3: Strong convergence, exploitation focus")
        elif self.testN == 4:  # Fluctuation
            success = basin_stability < 0.8 and uncertainty_vs_exploration > 0.5 and energy_vs_exploration > 0.3 and late_free_energy >= early_free_energy
            print("Test 4: Transitions, dynamic exploration")
        elif self.testN == 5:  # High Dim
            success = basin_stability < 0.8 and uncertainty_vs_exploration > 0.5 and energy_vs_exploration > 0.1 and late_free_energy < early_free_energy
            print("Test 5: Transitions, free energy minimized")
        else:
            success = False
        print("RESULT: Simulation " + ("CONFIRMS" if success else "PARTIALLY CONFIRMS") + " the dynamics of Theorem II.B")

        # print("RESULT: Simulation CONFIRMS the dynamics and conjectures of Theorem II.B")
        # print("- System converges to attractor basins")
        # print("- Higher uncertainty leads to more exploration")
        # print("- Metabolic energy enables and modulates exploration")
        # print("- System minimizes free energy over time")
        # print("- Dynamic uncertainty responds to exploration-exploitation needs")



# Run the simulation with different parameter settings
def run_tests():
    print("\n==============================")
    print("RUNNING THEOREM II.B SIMULATION ")
    print("==============================")

    test_params = [
        ("Base scenario", {"dims": 5, "timesteps": 1000, "energy_init": 85, "seed": 42,
         "energy_recovery_rate": 0.041, "energy_consumption_exploration": 1.0, "attractor_strengths":[1.5, 1.5]
         }),
        ("High energy scenario (more exploration)", {"dims": 5, "timesteps": 500, "energy_init": 250, "seed": 43,
         "energy_recovery_rate": 35.9999, "energy_consumption_exploration": 1.0, "attractor_strengths":[1.5, 1.9]
         }),
        ("Low energy scenario (more exploitation)", {"dims": 5, "timesteps": 1000, "energy_init": 50, "seed": 44,
         "energy_recovery_rate": 0.050, "energy_consumption_exploration": 5.0, "attractor_strengths":[0.7, 6.0]
         }),
        ("Energy fluctuation scenario", {"dims": 5, "timesteps": 1500, "energy_init": 100, "seed": 45,
         "energy_recovery_rate": 24.55555, "energy_consumption_exploration": 4.577756, "attractor_strengths":[10.0, 9.0]
         }),
        ("Higher dimensionality", {"dims": 10, "timesteps": 2000, "energy_init": 85, "seed": 46,
         "energy_recovery_rate": 12.4888, "energy_consumption_exploration": 1.4, "attractor_strengths":[10.3, 10.5]
         })
    ]

    # Store simulation objects for plotting
    simulations = []
    
    for i, (desc, params) in enumerate(test_params, 1):
        print(f"\n\n\n\nTEST {i}: {desc}")
        print("=" * 70)
        sim = TheoremTwoSimulation(**{k: v for k, v in params.items() if k in ["dims", "timesteps", "energy_init", "seed"]})
        sim.testN = i
        if "energy_recovery_rate" in params:
            sim.energy_recovery_rate = params["energy_recovery_rate"]
        if "energy_consumption_exploration" in params:
            sim.energy_consumption_exploration = params["energy_consumption_exploration"]
        if "attractor_strengths" in params:
            sim.potential_strengths = params["attractor_strengths"]
        sim.run_simulation()
        simulations.append(sim)

        if i == 5:
            break
    
    print("\n==============================")
    print("ALL THEOREM II.B SIMULATION TESTS COMPLETED")
    print("==============================")


if __name__ == "__main__":
    run_tests()